from __future__ import division, print_function, absolute_import

import numpy as np
from matplotlib import pyplot as plt
from CDTools.tools.cmath import *
from CDTools.tools.analysis import *
from CDTools.tools.interactions import *
from CDTools.tools.image_processing import *
import CDTools
import torch as t
import pickle
from RPI_tools import *
from scipy import fftpack as ffts
from scipy.optimize import minimize


#
# This file runs all the RIP analysis from the x-ray experiment
#

# Set to False if you don't have a GPU
GPU = True

# Filenames of the various ptycho/RIP datasets and of the saved ptychography
# results. All ptycho datasets have had the single shot images removed.
Siemens_ptycho_results = 'Siemens_ptycho.pickle'
Siemens_ptycho_filename = 'Siemens_ptycho.cxi'
Siemens_ss_filename = 'Siemens_ss.cxi'
FeGd_ptycho_results = 'FeGd_ptycho.pickle'
FeGd_ptycho_filename = 'FeGd_ptycho.cxi'
FeGd_ss_filename = 'FeGd_ss.cxi'


def _load_results(filename):
    """Unpickle and return a saved result stored under the data/ folder."""
    with open('data/' + filename, 'rb') as results_file:
        return pickle.load(results_file)


# We start by loading the calibration ptychography results
Siemens_ptycho = _load_results(Siemens_ptycho_results)
FeGd_ptycho = _load_results(FeGd_ptycho_results)

# The basis stored with the Siemens star result is used when plotting
# both samples' ptychography results
basis = Siemens_ptycho['basis']

# Only the first probe entry of each result is kept
Siemens_probe, Siemens_obj, Siemens_background = (
    Siemens_ptycho['probe'][0],
    Siemens_ptycho['obj'],
    Siemens_ptycho['background'],
)

FeGd_probe, FeGd_obj, FeGd_background = (
    FeGd_ptycho['probe'][0],
    FeGd_ptycho['obj'],
    FeGd_ptycho['background'],
)

# Plot the major results from the calibration, one figure per entry
_plotting = CDTools.tools.plotting
for _plot_fn, _image, _title in [
        (_plotting.plot_amplitude, FeGd_obj, 'FeGd Ptychography Amplitude'),
        (_plotting.plot_amplitude, Siemens_obj, 'Siemens Ptychography Amplitude'),
        (_plotting.plot_phase, Siemens_obj, 'Siemens Ptychography Phase'),
        (_plotting.plot_colorized, Siemens_probe, 'Reconstructed Probe from Siemens Star'),
        (_plotting.plot_colorized, FeGd_probe, 'Reconstructed Probe from FeGd'),
]:
    _plot_fn(_image, basis=basis)
    plt.title(_title)

# Measured complex attenuation of the Au Siemens star: the mean object value
# over the hard-coded "gold" patch divided by the mean over the "nogold" patch
nogold = np.mean(Siemens_obj[602:607, 613:618])
gold = np.mean(Siemens_obj[615:620, 600:605])
atten = gold / nogold

def predicted_atten(thickness, wavelen=1.754,
                    n=1 - 0.00334207434 - 0.00242176442j):
    """Return the predicted complex transmission through a uniform layer.

    The layer is modelled as a pure propagation-phase factor
    ``exp(-2j*pi*thickness/wavelen*(n-1))``. Its magnitude is the amplitude
    transmission and its angle is the phase shift relative to vacuum.

    Parameters
    ----------
    thickness : float or array_like
        Layer thickness, in the same length unit as ``wavelen``.
    wavelen : float, optional
        Wavelength. The default of 1.754 is presumably in nm (about 707 eV).
        NOTE(review): confirm the unit.
    n : complex, optional
        Complex refractive index. The default is the Au value taken from
        henke.lbl.gov for this experiment.

    Returns
    -------
    complex or ndarray
        The complex transmission, with the same shape as ``thickness``.
    """
    return np.exp(-2j * np.pi * thickness / wavelen * (n - 1))


def _atten_mismatch(thickness):
    """Squared distance between the modelled and the measured attenuation."""
    return np.abs(predicted_atten(thickness) - atten) ** 2


# Fit the Au thickness that best reproduces the measured attenuation,
# starting the search from a thickness of 200
_thickness_fit = minimize(_atten_mismatch, 200)
best_fit_thickness = _thickness_fit.x[0]

print('Au Siemens Star Attenuation:', np.abs(atten))
print('Au Siemens Star Phase Shift:', np.angle(atten))
print('Best Fit Au Thickness:', best_fit_thickness)
#plt.show()
#exit()

# Load the ptychography datasets and the single-shot datasets
_from_cxi = CDTools.datasets.Ptycho2DDataset.from_cxi

Siemens_ptycho_dataset = _from_cxi('data/' + Siemens_ptycho_filename)
FeGd_ptycho_dataset = _from_cxi('data/' + FeGd_ptycho_filename)

# Debugging snippet, kept for reference: inspect the Siemens ptycho data or
# reduce it to a single pattern
#Siemens_ss_dataset =CDTools.datasets.Ptycho2DDataset.from_cxi('data/' + Siemens_ptycho_filename)
#plt.close('all')
#Siemens_ss_dataset.inspect()
#plt.show()
#exit()
#Siemens_ss_dataset.patterns = Siemens_ss_dataset.patterns[2192:2193]
#Siemens_ss_dataset.translations = Siemens_ss_dataset.translations[2192:2193]

Siemens_ss_dataset = _from_cxi('data/' + Siemens_ss_filename)
FeGd_ss_dataset = _from_cxi('data/' + FeGd_ss_filename)


# Mask off the central zeroth order peak in both single-shot datasets.
# This mask is respected by the reconstruction algorithm.
for _ss_dataset in (Siemens_ss_dataset, FeGd_ss_dataset):
    _ss_dataset.mask[120:140, 120:145] = False

wavelength = Siemens_ss_dataset.wavelength

# Show the square root of the first pattern from each sample
for _ss_dataset, _title in (
        (FeGd_ss_dataset, 'Example Diffraction Pattern From FeGd'),
        (Siemens_ss_dataset, 'Example Diffraction Pattern From Siemens Star'),
):
    plt.figure()
    plt.imshow(np.sqrt(_ss_dataset.patterns[0]))
    plt.title(_title)
    plt.colorbar()


# Element [1] of the first entry of each single-shot dataset
# (presumably the measured pattern -- confirm against the dataset API)
Siemens_intensities = Siemens_ss_dataset[0][1]
FeGd_intensities = FeGd_ss_dataset[0][1]


def _as_float_tensor(array):
    """Wrap ``array`` in a torch Tensor and cast it to float32."""
    return t.Tensor(array).to(dtype=t.float32)


# Now we convert the relevant data to pytorch
t_Siemens_probe = complex_to_torch(Siemens_probe).to(dtype=t.float32)
t_Siemens_intensities = _as_float_tensor(Siemens_intensities)
t_Siemens_background = _as_float_tensor(Siemens_background)
t_FeGd_probe = complex_to_torch(FeGd_probe).to(dtype=t.float32)
t_FeGd_intensities = _as_float_tensor(FeGd_intensities)
t_FeGd_background = _as_float_tensor(FeGd_background)


# Settings shared by every single-shot retrieval below
resolution = 70
_retrieval_kwargs = dict(optimizer='Adam', GPU=GPU, schedule=True, numtries=25)

# Two retrievals that use a probe calibrated on the same sample
Siemens_retrieved, Siemens_loss = reconstruct(
    t_Siemens_intensities, t_Siemens_probe, resolution, 0.4, 1000,
    background=t_Siemens_background, mask=Siemens_ss_dataset.mask,
    **_retrieval_kwargs)
Siemens_retrieval = torch_to_complex(Siemens_retrieved)

FeGd_retrieved, FeGd_loss = reconstruct(
    t_FeGd_intensities, t_FeGd_probe, resolution, 0.4, 1000,
    background=t_FeGd_background, mask=FeGd_ss_dataset.mask,
    **_retrieval_kwargs)
FeGd_retrieval = torch_to_complex(FeGd_retrieved)

# Cross-sample attempt: FeGd data with the Siemens star probe, run over an
# ensemble of probe defocuses
defocuses = np.linspace(-50e-6, -30e-6, 41)
FeGd_Siemens_retrieved, FeGd_Siemens_losses = reconstruct_defocus(
    t_FeGd_intensities, t_Siemens_probe, basis, wavelength, resolution,
    0.4, 1000, defocuses, background=t_FeGd_background,
    mask=FeGd_ss_dataset.mask, **_retrieval_kwargs)
FeGd_Siemens_retrieval = torch_to_complex(FeGd_Siemens_retrieved)

# Scale factor between the ptychography basis and the (coarser) basis
# used for plotting the single-shot retrievals
pixel_ratio = Siemens_probe.shape[0] / resolution
big_basis = basis * pixel_ratio


# Plot each retrieval next to the matching ptychography result
_plotting = CDTools.tools.plotting
for _plot_fn, _image, _basis, _title in [
        (_plotting.plot_amplitude, Siemens_obj, basis,
         'Siemens Star Ptychography Amplitude'),
        (_plotting.plot_amplitude, Siemens_retrieval, big_basis,
         'Retrieved Siemens Star Amplitude'),
        (_plotting.plot_phase, Siemens_retrieval, big_basis,
         'Retrieved Siemens Star Phase'),
        (_plotting.plot_amplitude, FeGd_obj, basis,
         'FeGd Ptychography Amplitude'),
        (_plotting.plot_amplitude, FeGd_retrieval, big_basis,
         'Retrieved FeGd Amplitude'),
        (_plotting.plot_phase, FeGd_retrieval, big_basis,
         'Retrieved FeGd Phase'),
        (_plotting.plot_amplitude, FeGd_Siemens_retrieval, big_basis,
         'Retrieved FeGd Amplitude Using Siemens Star Probe'),
        (_plotting.plot_phase, FeGd_Siemens_retrieval, big_basis,
         'Retrieved FeGd Phase Using Siemens Star Probe'),
]:
    _plot_fn(_image, basis=_basis)
    plt.title(_title)

# Retrieval loss versus the probe defocus used in the ensemble
plt.figure()
plt.plot(defocuses * 1e6, FeGd_Siemens_losses)
plt.xlabel('Probe Defocus (um)')
plt.ylabel('RMS Diffraction Error')


#
#
# Now we calculate the Fourier ring correlations (FRCs) between the various
# images. This calculation requires extracting specific portions of images
# and so is rather tangled, but it all boils down to extracting the
# correct region of a ptychography reconstruction and downsampling it appropriately.
#
#


# Field of view (in single-shot pixels) of the region compared in the FRCs,
# and the offset of that region within each single-shot retrieval.
# An earlier choice was fov = 30 with offsets of [20, 20].
fov = 40
offset_FeGd_ss = [15, 15]
offset_Siemens_ss = [15, 15]


def _ptycho_offset(ptycho_dataset, ss_dataset, ss_offset):
    """Locate the single-shot crop within the matching ptychography object.

    A FancyPtycho model is built from ``ptycho_dataset`` with padding=0 and
    auto_center=False, mirroring the model used for the ptychography
    reconstruction so that it has the same geometry. The first single-shot
    translation is converted to a pixel position in that model's object
    frame. The single-shot crop offset is added after scaling by pixel_ratio.

    Returns
    -------
    (model, offset) : the model and the pixel offset as an integer ndarray.
        The conversion truncates toward zero, as the original code did.
    """
    model = CDTools.models.FancyPtycho.from_dataset(
        ptycho_dataset, padding=0, auto_center=False)
    offset = translations_to_pixel(model.probe_basis,
                                   ss_dataset.translations)[0]
    offset -= model.min_translation
    offset = offset.numpy() + pixel_ratio * np.array(ss_offset)
    # np.int was only an alias of the builtin int and was removed in
    # NumPy 1.24, so astype(np.int) raises AttributeError on modern NumPy.
    return model, offset.astype(int)


# We calculate the portion of each ptychography result which should overlap
# with the corresponding single-shot reconstruction.
Siemens_ptycho_model, offset_Siemens_ptycho = _ptycho_offset(
    Siemens_ptycho_dataset, Siemens_ss_dataset, offset_Siemens_ss)
FeGd_ptycho_model, offset_FeGd_ptycho = _ptycho_offset(
    FeGd_ptycho_dataset, FeGd_ss_dataset, offset_FeGd_ss)


def _hann_window_2d(size):
    """Return a separable (size, size) sin**2 (Hann) window as float32."""
    hann_1d = np.sin(np.pi * np.arange(size) / (size - 1)) ** 2
    return np.outer(hann_1d, hann_1d).astype(np.float32)


# Size of the ptychography crop (in ptychography pixels) that covers the
# same area as fov single-shot pixels, and the margin trimmed from each side
# of its centred spectrum when downsampling to fov pixels
big_fov = int(pixel_ratio * fov)
pad = (big_fov - fov) // 2

# Apodizing windows for the ptychography crops and the single-shot crops
big_window = _hann_window_2d(big_fov)
lil_window = _hann_window_2d(fov)

#plt.close('all')
#plt.imshow(big_window)
#plt.figure()
#plt.imshow(lil_window)
#plt.show()



def _crop_ss(retrieval, offset):
    """Window the fov x fov region of a single-shot retrieval at offset."""
    row, col = offset[0], offset[1]
    return lil_window * retrieval[row:row + fov, col:col + fov]


def _crop_and_downsample(obj, offset):
    """Cut a big_fov x big_fov patch from a ptychography object at offset,
    window it, and downsample it to fov x fov by keeping only the central
    fov x fov part of its centred Fourier spectrum."""
    # TODO: crop *after* downsampling instead, to avoid artifacts from the
    # edge of the crop.
    row, col = offset[0], offset[1]
    patch = big_window * obj[row:row + big_fov, col:col + big_fov]
    spectrum = ffts.fftshift(np.fft.fft2(ffts.ifftshift(patch)))
    central = spectrum[pad:pad + fov, pad:pad + fov]
    return ffts.fftshift(np.fft.ifft2(ffts.ifftshift(central)))


def _align_to(reference, image):
    """Estimate the shift between reference and image with find_shift, then
    apply it to image with sinc_subpixel_shift.

    Returns (shift, shifted image as a complex array)."""
    shift = find_shift(complex_to_torch(reference), complex_to_torch(image))
    shifted = torch_to_complex(
        sinc_subpixel_shift(complex_to_torch(image), shift))
    return shift, shifted


# Siemens star: single-shot crop, matching downsampled ptychography crop,
# and the ptychography crop aligned to the single-shot one.
# (To skip alignment, use Siemens_shifted = Siemens_ptycho_cropped.)
Siemens_ss_cropped = _crop_ss(Siemens_retrieval, offset_Siemens_ss)
Siemens_ptycho_cropped = _crop_and_downsample(Siemens_obj, offset_Siemens_ptycho)
Siemens_shift, Siemens_shifted = _align_to(Siemens_ss_cropped,
                                           Siemens_ptycho_cropped)

# The same for the FeGd data
FeGd_ss_cropped = _crop_ss(FeGd_retrieval, offset_FeGd_ss)
FeGd_ptycho_cropped = _crop_and_downsample(FeGd_obj, offset_FeGd_ptycho)
FeGd_shift, FeGd_shifted = _align_to(FeGd_ss_cropped, FeGd_ptycho_cropped)

# And one final time for the FeGd retrieval that used the Siemens star probe
FeGd_Siemens_ss_cropped = _crop_ss(FeGd_Siemens_retrieval, offset_FeGd_ss)
# Manually correct for a linear phase ramp (hard-coded coefficients)
ys, xs = np.mgrid[:fov, :fov]
FeGd_Siemens_ss_cropped *= np.exp(-0.037j * xs) * np.exp(0.073j * ys)
FeGd_Siemens_shift, FeGd_Siemens_shifted = _align_to(FeGd_Siemens_ss_cropped,
                                                     FeGd_ptycho_cropped)



# Now that we have aligned images, we can calculate FRCs. The full cropped
# images are compared (im_slice=np.s_[:,:]); no outer border is trimmed.
# Every FRC uses the same settings; threshold is the same for all three
_frc_kwargs = dict(im_slice=np.s_[:, :], snr=1, nbins=20, window='None')

freqs, Siemens_frc, threshold = calc_frc(
    Siemens_ss_cropped, Siemens_shifted, big_basis, **_frc_kwargs)
freqs, FeGd_frc, threshold = calc_frc(
    FeGd_ss_cropped, FeGd_shifted, big_basis, **_frc_kwargs)
freqs, FeGd_Siemens_frc, threshold = calc_frc(
    FeGd_Siemens_ss_cropped, FeGd_Siemens_shifted, big_basis, **_frc_kwargs)


# Plot the three FRC curves together with the threshold curve
plt.figure(figsize=(3.5, 3))
for _curve, _label in (
        (Siemens_frc, 'Siemens'),
        (FeGd_frc, 'FeGd 1'),
        (FeGd_Siemens_frc, 'FeGd 2'),
        (threshold, 'Threshold'),
):
    plt.plot(freqs * 1e6, _curve, label=_label)
plt.legend()
plt.tight_layout()

plt.show()

